{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0_xya1nyDHfY",
   "metadata": {
    "id": "0_xya1nyDHfY"
   },
   "source": [
    "<table style=\"width:100%\">\n",
    "<tr>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<font size=\"2\">\n",
    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
    "</font>\n",
    "</td>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
    "</td>\n",
    "</tr>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "l62zIRRSBy_R",
   "metadata": {
    "id": "l62zIRRSBy_R"
   },
   "source": [
    "# Converting Llama 2 to Llama 3.2 From Scratch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aFmxTQbwCUMl",
   "metadata": {
    "id": "aFmxTQbwCUMl"
   },
   "source": [
    "- This is a follow-up notebook to [Converting a From-Scratch GPT Architecture to Llama 2](./converting-gpt-to-llama2.ipynb), converting Meta AI's Llama 2 architecture model step by step to Llama 3, Llama 3.1, and Llama 3.2\n",
    "- The explanations are purposefully kept minimal in this notebook so as not to bloat it unnecessarily and to keep the focus on the main code\n",
    "- For more information about the architectures, please see the Llama 2 and Llama 3 papers\n",
    " - [Llama 2: Open Foundation and Fine-Tuned Chat Models (2023)](https://arxiv.org/abs/2307.09288)\n",
    " - [The Llama 3 Herd of Models (2024)](https://arxiv.org/abs/2407.21783)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ohhMKUWvGm9z",
   "metadata": {
    "id": "ohhMKUWvGm9z"
   },
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/gpt2-to-llama2-llama3.webp?1\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ws0wsUzwLH2k",
   "metadata": {
    "id": "ws0wsUzwLH2k"
   },
   "outputs": [],
   "source": [
    "# pip install -r requirements-extra.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "JBpQwU89ETA1",
   "metadata": {
    "id": "JBpQwU89ETA1"
   },
   "source": [
    "- Packages that are being used in this notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "34a9a440-84c2-42cc-808b-38677cb6af8a",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "34a9a440-84c2-42cc-808b-38677cb6af8a",
    "outputId": "e3d3d4b6-ee63-4e28-d794-e8b0bdd931fd"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "blobfile version: 3.0.0\n",
      "huggingface_hub version: 0.24.7\n",
      "tiktoken version: 0.8.0\n",
      "torch version: 2.4.1+cu121\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "\n",
    "pkgs = [\n",
    "    \"blobfile\",         # to download pretrained weights\n",
    "    \"huggingface_hub\",  # to download pretrained weights\n",
    "    \"tiktoken\",         # to implement the tokenizer\n",
    "    \"torch\",            # to implement the model\n",
    "]\n",
    "for p in pkgs:\n",
    "    print(f\"{p} version: {version(p)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "UJJneXpTEg4W",
   "metadata": {
    "id": "UJJneXpTEg4W"
   },
   "source": [
    "&nbsp;\n",
    "# 1. Convert the Llama model implementation step by step"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "v1zpfX2GHBKa",
   "metadata": {
    "id": "v1zpfX2GHBKa"
   },
   "source": [
    "- If you are new to implementing LLM architectures, I recommend starting with [chapter 4](../../ch04/01_main-chapter-code/ch04.ipynb), which walks you through the implementation of the original GPT architecture step by step\n",
    "- The [Converting a From-Scratch GPT Architecture to Llama 2](./converting-gpt-to-llama2.ipynb) notebook then implements the Llama-specific components, such as RMSNorm layers, SiLU and SwiGLU activations, RoPE (rotary position embeddings), and the SentencePiece tokenizer\n",
    "- This notebook takes the Llama 2 architecture and transforms it into the Llama 3 architecture by\n",
    "    1. modifying the rotary embeddings\n",
    "    2. implementing grouped-query attention\n",
    "    3. and using a customized version of the GPT-4 tokenizer\n",
    "- Later, we load the original Llama 3 weights shared by Meta AI into the architecture"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c14b9121-abe1-4a46-99b8-acdef71e5b41",
   "metadata": {
    "id": "c14b9121-abe1-4a46-99b8-acdef71e5b41"
   },
   "source": [
    "&nbsp;\n",
    "## 1.1 Reusing Llama 2 components"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dgDhJGJ6xR4e",
   "metadata": {
    "id": "dgDhJGJ6xR4e"
   },
   "source": [
    "- Llama 2 is actually quite similar to Llama 3, as mentioned above and illustrated in the figure at the top of this notebook\n",
    "- This means that we can import several building blocks from the [Llama 2 notebook](./converting-gpt-to-llama2.ipynb) using the following code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a5bc3948-231b-4f1f-8d41-24ad0b7643d0",
   "metadata": {
    "id": "a5bc3948-231b-4f1f-8d41-24ad0b7643d0"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import io\n",
    "import nbformat\n",
    "import types\n",
    "\n",
    "def import_from_notebook():\n",
    "    def import_definitions_from_notebook(fullname, names):\n",
    "        current_dir = os.getcwd()\n",
    "        path = os.path.join(current_dir, fullname + \".ipynb\")\n",
    "        path = os.path.normpath(path)\n",
    "\n",
    "        # Load the notebook\n",
    "        if not os.path.exists(path):\n",
    "            raise FileNotFoundError(f\"Notebook file not found at: {path}\")\n",
    "\n",
    "        with io.open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "            nb = nbformat.read(f, as_version=4)\n",
    "\n",
    "        # Create a module to store the imported functions and classes\n",
    "        mod = types.ModuleType(fullname)\n",
    "        sys.modules[fullname] = mod\n",
    "\n",
    "        # Go through the notebook cells and only execute function or class definitions\n",
    "        for cell in nb.cells:\n",
    "            if cell.cell_type == \"code\":\n",
    "                cell_code = cell.source\n",
    "                for name in names:\n",
    "                    # Check for function or class definitions\n",
    "                    if f\"def {name}\" in cell_code or f\"class {name}\" in cell_code:\n",
    "                        exec(cell_code, mod.__dict__)\n",
    "        return mod\n",
    "\n",
    "    fullname = \"converting-gpt-to-llama2\"\n",
    "    names = [\"precompute_rope_params\", \"compute_rope\", \"SiLU\", \"FeedForward\", \"RMSNorm\", \"MultiHeadAttention\"]\n",
    "\n",
    "    return import_definitions_from_notebook(fullname, names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d546032d-fce4-47cf-8d0e-682b78b21c61",
   "metadata": {
    "id": "d546032d-fce4-47cf-8d0e-682b78b21c61"
   },
   "outputs": [],
   "source": [
    "imported_module = import_from_notebook()\n",
    "\n",
    "# We need to redefine precompute_rope_params\n",
    "# precompute_rope_params = getattr(imported_module, \"precompute_rope_params\", None)\n",
    "compute_rope = getattr(imported_module, \"compute_rope\", None)\n",
    "SiLU = getattr(imported_module, \"SiLU\", None)\n",
    "FeedForward = getattr(imported_module, \"FeedForward\", None)\n",
    "RMSNorm = getattr(imported_module, \"RMSNorm\", None)\n",
    "\n",
    "# MultiHeadAttention only for comparison purposes\n",
    "MultiHeadAttention = getattr(imported_module, \"MultiHeadAttention\", None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f",
   "metadata": {
    "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f"
   },
   "source": [
    "&nbsp;\n",
    "## 1.2 Modified RoPE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "m9_oDcHCx8VI",
   "metadata": {
    "id": "m9_oDcHCx8VI"
   },
   "source": [
    "- Llama 3 uses rotary position embeddings (RoPE) similar to Llama 2 (for a detailed explanation, please see the [RoPE paper](https://arxiv.org/abs/2104.09864))\n",
    "- There are some subtle differences in the RoPE settings, though\n",
    " - Llama 3 now supports up to 8,192 tokens, twice as many as Llama 2 (4,096)\n",
    " - The base value in the so-called RoPE $\\theta$ was increased from 10,000 (Llama 2) to 500,000 (Llama 3) in the following equation (adapted from the [RoPE paper](https://arxiv.org/abs/2104.09864))\n",
    "\n",
    "$$\\Theta = \\left\\{\\theta_i = \\text{base}^{\\frac{-2(i-1)}{d}}, i \\in \\left[1, 2, ..., d/2\\right]\\right\\}$$\n",
    "\n",
    "- These $\\theta$ values are a set of predefined parameters that are used to determine the rotational angles in the rotary matrix, where $d$ is the dimensionality of the embedding space\n",
    "- Increasing the base from 10,000 to 500,000 makes the frequencies (or rotation angles) decay faster across the dimensions, which means that higher dimensions are associated with smaller angles than before (essentially, the frequencies are stretched toward longer wavelengths)\n",
    "- In addition, we introduce a `freq_config` section in the code below that adjusts the frequencies; however, it is only needed for Llama 3.1 and Llama 3.2, not for Llama 3, so we will revisit `freq_config` later (it is set to `None` and ignored by default)"
   ]
  },
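  {
   "cell_type": "markdown",
   "id": "rope-theta-sketch-01",
   "metadata": {},
   "source": [
    "- As a quick sanity check of the equation above, the following pure-Python sketch (not part of the original Llama code) evaluates $\\theta_i$ for both base values; the `head_dim = 16` setting and the helper name `rope_thetas` are illustrative assumptions\n",
    "\n",
    "```python\n",
    "def rope_thetas(base, head_dim):\n",
    "    # theta_i = base^(-2(i-1)/d) for i = 1, ..., d/2\n",
    "    return [base ** (-2 * (i - 1) / head_dim) for i in range(1, head_dim // 2 + 1)]\n",
    "\n",
    "\n",
    "head_dim = 16  # small illustrative value\n",
    "thetas_llama2 = rope_thetas(10_000, head_dim)\n",
    "thetas_llama3 = rope_thetas(500_000, head_dim)\n",
    "\n",
    "for i, (t2, t3) in enumerate(zip(thetas_llama2, thetas_llama3), start=1):\n",
    "    print(f'i={i}: base 10,000 -> {t2:.3e} | base 500,000 -> {t3:.3e}')\n",
    "```\n",
    "\n",
    "- The first value is 1 for both bases, and the values shrink as $i$ grows at a rate that is set by the base"
   ]
  },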
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6Upl109OOAcu",
   "metadata": {
    "id": "6Upl109OOAcu"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):\n",
    "    assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
    "\n",
    "    # Compute the inverse frequencies\n",
    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
    "\n",
    "    ################################ NEW ###############################################\n",
    "    # Frequency adjustments\n",
    "    if freq_config is not None:\n",
    "        low_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"low_freq_factor\"]\n",
    "        high_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"high_freq_factor\"]\n",
    "\n",
    "        wavelen = 2 * torch.pi / inv_freq\n",
    "\n",
    "        inv_freq_llama = torch.where(\n",
    "            wavelen > low_freq_wavelen, inv_freq / freq_config[\"factor\"], inv_freq\n",
    "        )\n",
    "\n",
    "        smooth_factor = (freq_config[\"original_context_length\"] / wavelen - freq_config[\"low_freq_factor\"]) / (\n",
    "            freq_config[\"high_freq_factor\"] - freq_config[\"low_freq_factor\"]\n",
    "        )\n",
    "\n",
    "        smoothed_inv_freq = (\n",
    "            (1 - smooth_factor) * (inv_freq / freq_config[\"factor\"]) + smooth_factor * inv_freq\n",
    "        )\n",
    "\n",
    "        is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)\n",
    "        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n",
    "        inv_freq = inv_freq_llama\n",
    "    ####################################################################################\n",
    "\n",
    "\n",
    "    # Generate position indices\n",
    "    positions = torch.arange(context_length)\n",
    "\n",
    "    # Compute the angles\n",
    "    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)\n",
    "\n",
    "    # Expand angles to match the head_dim\n",
    "    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)\n",
    "\n",
    "    # Precompute sine and cosine\n",
    "    cos = torch.cos(angles)\n",
    "    sin = torch.sin(angles)\n",
    "\n",
    "    return cos, sin"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "jJBvO0YMJBXR",
   "metadata": {
    "id": "jJBvO0YMJBXR"
   },
   "source": [
    "- To summarize, what's new so far for Llama 3 compared to Llama 2 are the context length and theta base parameter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "56c37216-e022-4603-be16-f9d3eaeaf4a1",
   "metadata": {
    "id": "56c37216-e022-4603-be16-f9d3eaeaf4a1"
   },
   "outputs": [],
   "source": [
    "# Instantiate RoPE parameters\n",
    "\n",
    "llama_2_context_len = 4096\n",
    "llama_3_context_len = 8192\n",
    "\n",
    "llama_2_theta_base = 10_000\n",
    "llama_3_theta_base = 500_000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "_V8v6i7MJItU",
   "metadata": {
    "id": "_V8v6i7MJItU"
   },
   "source": [
    "- The usage remains the same as before in Llama 2:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dae70c8a-eb18-40f9-a2e5-a6af2a57628b",
   "metadata": {
    "id": "dae70c8a-eb18-40f9-a2e5-a6af2a57628b"
   },
   "outputs": [],
   "source": [
    "# Settings\n",
    "batch_size = 2\n",
    "num_heads = 4\n",
    "head_dim = 16\n",
    "\n",
    "# Instantiate RoPE parameters\n",
    "cos, sin = precompute_rope_params(\n",
    "    head_dim=head_dim,\n",
    "    theta_base=llama_3_theta_base,\n",
    "    context_length=llama_3_context_len\n",
    ")\n",
    "\n",
    "# Dummy query and key tensors\n",
    "torch.manual_seed(123)\n",
    "queries = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)\n",
    "keys = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)\n",
    "\n",
    "# Apply rotary position embeddings\n",
    "queries_rot = compute_rope(queries, cos, sin)\n",
    "keys_rot = compute_rope(keys, cos, sin)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd19b75c-cf25-47b8-a010-6733fc0e9a8a",
   "metadata": {
    "id": "cd19b75c-cf25-47b8-a010-6733fc0e9a8a"
   },
   "source": [
    "&nbsp;\n",
    "## 1.3 Grouped-query attention"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "111c7d3f-fded-49e8-a617-9fe67b81dddc",
   "metadata": {
    "id": "111c7d3f-fded-49e8-a617-9fe67b81dddc"
   },
   "source": [
    "- In this section, we replace multi-head attention (MHA) with an alternative mechanism called grouped-query attention (GQA)\n",
    "- In short, one can think of GQA as a more compute- and parameter-efficient version of MHA\n",
    "- In GQA, we reduce the number of key and value projections by sharing them among multiple attention heads\n",
    "- Each attention head still has its unique query, but these queries attend to the same group of keys and values\n",
    "- Below is an illustration of GQA with 2 key-value groups (kv-groups):\n",
    "\n",
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/grouped-query-attention.webp\" width=\"500px\">\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "perAYa2R_KW2",
   "metadata": {
    "id": "perAYa2R_KW2"
   },
   "source": [
    "- The main idea behind GQA is to reduce the number of unique key and value heads by sharing each key-value pair among a group of query heads, which reduces the size of some of the matrix multiplications and the number of parameters compared to MHA without significantly reducing modeling performance\n",
    "- The GQA code is very similar to MHA (I highlighted the changes below via the \"NEW\" sections)\n",
    "- In short, the main change in GQA is that each key-value group needs to be repeated to match the number of query heads it is associated with, as implemented below"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "842aa71a-4659-424e-8830-392bd6ae86af",
   "metadata": {},
   "source": [
    "- In addition, we also introduce a `SharedBuffers` class that will allow us to reuse the `mask`, `cos`, and `sin` tensors in the transformer blocks to improve efficiency (this will be crucial when working with models such as Llama 3.1 and 3.2 later, which support up to 131k input tokens)"
   ]
  },
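  {
   "cell_type": "markdown",
   "id": "shared-mask-memory-sketch-02",
   "metadata": {},
   "source": [
    "- To get a feel for why sharing these buffers matters, the following back-of-the-envelope sketch estimates the memory taken by the causal mask alone; it assumes a float32 mask (the default dtype of `torch.ones`), interprets 131k as 131,072 tokens, and uses 32 transformer blocks as an illustrative layer count; the helper name `mask_bytes` is hypothetical\n",
    "\n",
    "```python\n",
    "def mask_bytes(context_length, bytes_per_element=4):\n",
    "    # A mask created via torch.ones(n, n) holds n * n float32 values (4 bytes each)\n",
    "    return context_length * context_length * bytes_per_element\n",
    "\n",
    "\n",
    "n_layers = 32  # illustrative number of transformer blocks\n",
    "\n",
    "for name, ctx in [('8,192 tokens', 8_192), ('131,072 tokens', 131_072)]:\n",
    "    per_mask_gib = mask_bytes(ctx) / 1024**3\n",
    "    print(f'{name}: {per_mask_gib:.2f} GiB per mask, '\n",
    "          f'{per_mask_gib * n_layers:.0f} GiB for {n_layers} separate copies')\n",
    "```\n",
    "\n",
    "- Under these assumptions, a single 131,072 × 131,072 mask already occupies 64 GiB, so keeping a separate copy in every transformer block would be impractical"
   ]
  },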
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9b12e674-ef08-4dd7-8843-615b65b39c91",
   "metadata": {
    "id": "9b12e674-ef08-4dd7-8843-615b65b39c91"
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "############################# NEW  #############################\n",
    "class SharedBuffers:\n",
    "    _buffers = {}\n",
    "\n",
    "    @staticmethod\n",
    "    def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
    "        key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
    "\n",
    "        if key not in SharedBuffers._buffers:\n",
    "            # Create or fetch the buffers\n",
    "            mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
    "            cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
    "            if dtype is not None:\n",
    "                cos = cos.to(dtype)\n",
    "                sin = sin.to(dtype)\n",
    "            SharedBuffers._buffers[key] = (mask, cos, sin)\n",
    "\n",
    "        return SharedBuffers._buffers[key]\n",
    "############################# NEW  #############################\n",
    "\n",
    "\n",
    "class GroupedQueryAttention(nn.Module):\n",
    "    def __init__(\n",
    "            self, d_in, d_out, context_length, num_heads,\n",
    "            num_kv_groups,       # NEW\n",
    "            rope_base=10_000,    # NEW\n",
    "            rope_config=None,    # NEW\n",
    "            dtype=None\n",
    "        ):\n",
    "        super().__init__()\n",
    "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
    "        assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"  # NEW\n",
    "\n",
    "        self.d_out = d_out\n",
    "        self.num_heads = num_heads\n",
    "        self.head_dim = d_out // num_heads\n",
    "\n",
    "        ############################# NEW  #############################\n",
    "        # self.W_key = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
    "        # self.W_value = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
    "        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n",
    "        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n",
    "        self.num_kv_groups = num_kv_groups\n",
    "        self.group_size = num_heads // num_kv_groups\n",
    "        ################################################################\n",
    "\n",
    "        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
    "        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
    "\n",
    "        ############################# NEW  #############################\n",
    "        # Fetch buffers using SharedBuffers\n",
    "        mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
    "        ############################# NEW  #############################\n",
    "        \n",
    "        self.register_buffer(\"mask\", mask)\n",
    "        self.register_buffer(\"cos\", cos)\n",
    "        self.register_buffer(\"sin\", sin)\n",
    "\n",
    "    def forward(self, x):\n",
    "        b, num_tokens, d_in = x.shape\n",
    "\n",
    "        queries = self.W_query(x)  # Shape: (b, num_tokens, d_out)\n",
    "        keys = self.W_key(x)  # Shape: (b, num_tokens, num_kv_groups * head_dim)\n",
    "        values = self.W_value(x)  # Shape: (b, num_tokens, num_kv_groups * head_dim)\n",
    "\n",
    "        # Reshape queries, keys, and values\n",
    "        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
    "\n",
    "        ##################### NEW  #####################\n",
    "        # keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)\n",
    "        # values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
    "        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
    "        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n",
    "        ################################################\n",
    "\n",
    "        # Transpose keys, values, and queries\n",
    "        keys = keys.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
    "        values = values.transpose(1, 2)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
    "        queries = queries.transpose(1, 2)  # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
    "\n",
    "        # Apply RoPE\n",
    "        keys = compute_rope(keys, self.cos, self.sin)\n",
    "        queries = compute_rope(queries, self.cos, self.sin)\n",
    "\n",
    "        ##################### NEW  #####################\n",
    "        # Expand keys and values to match the number of heads\n",
    "        # Shape: (b, num_heads, num_tokens, head_dim)\n",
    "\n",
    "        keys = keys.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
    "        values = values.repeat_interleave(self.group_size, dim=1)  # Shape: (b, num_heads, num_tokens, head_dim)\n",
    "        # For example, before repeat_interleave along dim=1 (query groups):\n",
    "        #   [K1, K2]\n",
    "        # After repeat_interleave (each query group is repeated group_size times):\n",
    "        #   [K1, K1, K2, K2]\n",
    "        # If we used regular repeat instead of repeat_interleave, we'd get:\n",
    "        #   [K1, K2, K1, K2]\n",
    "        ################################################\n",
    "\n",
    "        # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
    "        # Shape: (b, num_heads, num_tokens, num_tokens)\n",
    "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n",
    "\n",
    "        # Original mask truncated to the number of tokens and converted to boolean\n",
    "        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
    "\n",
    "        # Use the mask to fill attention scores\n",
    "        attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
    "\n",
    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
    "        assert keys.shape[-1] == self.head_dim\n",
    "\n",
    "        # Shape: (b, num_tokens, num_heads, head_dim)\n",
    "        context_vec = (attn_weights @ values).transpose(1, 2)\n",
    "\n",
    "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
    "        context_vec = context_vec.reshape(b, num_tokens, self.d_out)\n",
    "        context_vec = self.out_proj(context_vec)  # optional projection\n",
    "\n",
    "        return context_vec"
   ]
  },
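  {
   "cell_type": "markdown",
   "id": "kv-group-mapping-sketch-03",
   "metadata": {},
   "source": [
    "- The `repeat_interleave` step in the forward pass above can be mimicked with plain Python lists; the sketch below is only an illustration (the list-based `repeat_interleave` helper is a hypothetical stand-in for the PyTorch method), using 4 heads and 2 kv-groups as example values\n",
    "\n",
    "```python\n",
    "def repeat_interleave(items, repeats):\n",
    "    # Pure-Python stand-in: repeat each element `repeats` times in place\n",
    "    return [item for item in items for _ in range(repeats)]\n",
    "\n",
    "\n",
    "num_heads = 4\n",
    "num_kv_groups = 2\n",
    "group_size = num_heads // num_kv_groups\n",
    "\n",
    "kv_groups = ['K1', 'K2']\n",
    "expanded = repeat_interleave(kv_groups, group_size)\n",
    "print(expanded)  # ['K1', 'K1', 'K2', 'K2']\n",
    "\n",
    "# Equivalent view: query head h reads from key-value group h // group_size\n",
    "for head in range(num_heads):\n",
    "    assert expanded[head] == kv_groups[head // group_size]\n",
    "```\n",
    "\n",
    "- In other words, neighboring query heads share the same keys and values, which is why the ordering produced by `repeat_interleave` (rather than a regular `repeat`) matters"
   ]
  },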
  {
   "cell_type": "markdown",
   "id": "roAXSwJs9hR8",
   "metadata": {
    "id": "roAXSwJs9hR8"
   },
   "source": [
    "- To illustrate the parameter savings, consider the following multi-head attention example from the GPT and Llama 2 code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b4b8f085-349e-4674-a3f0-78fde0664fac",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "b4b8f085-349e-4674-a3f0-78fde0664fac",
    "outputId": "9da09d72-43b1-45af-d46f-6928ea4af33a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W_key: torch.Size([4096, 4096])\n",
      "W_value: torch.Size([4096, 4096])\n",
      "W_query: torch.Size([4096, 4096])\n"
     ]
    }
   ],
   "source": [
    "# Settings\n",
    "batch_size = 1\n",
    "context_len = 3000\n",
    "max_context_len = 8192\n",
    "embed_dim = 4096\n",
    "num_heads = 32\n",
    "\n",
    "\n",
    "example_batch = torch.randn((batch_size, context_len, embed_dim))\n",
    "\n",
    "mha = MultiHeadAttention(\n",
    "    d_in=embed_dim,\n",
    "    d_out=embed_dim,\n",
    "    context_length=max_context_len,\n",
    "    num_heads=num_heads\n",
    ")\n",
    "\n",
    "mha(example_batch)\n",
    "\n",
    "print(\"W_key:\", mha.W_key.weight.shape)\n",
    "print(\"W_value:\", mha.W_value.weight.shape)\n",
    "print(\"W_query:\", mha.W_query.weight.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "IMQtFkcQ9sXC",
   "metadata": {
    "id": "IMQtFkcQ9sXC"
   },
   "source": [
    "- Now, if we use grouped-query attention instead, with 8 kv-groups (that's how many Llama 3 8B uses), we can see that the number of rows of the key and value matrices is reduced by a factor of 4 (because 32 attention heads divided by 8 kv-groups is 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "15e65d3c-7b42-4ed3-bfee-bb09578657bb",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "15e65d3c-7b42-4ed3-bfee-bb09578657bb",
    "outputId": "69709a78-2aaa-4597-8142-2f44eb59753f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W_key: torch.Size([1024, 4096])\n",
      "W_value: torch.Size([1024, 4096])\n",
      "W_query: torch.Size([4096, 4096])\n"
     ]
    }
   ],
   "source": [
    "gqa = GroupedQueryAttention(\n",
    "    d_in=embed_dim,\n",
    "    d_out=embed_dim,\n",
    "    context_length=max_context_len,\n",
    "    num_heads=num_heads,\n",
    "    num_kv_groups=8,\n",
    "    rope_base=llama_3_theta_base\n",
    ")\n",
    "\n",
    "gqa(example_batch)\n",
    "\n",
    "print(\"W_key:\", gqa.W_key.weight.shape)\n",
    "print(\"W_value:\", gqa.W_value.weight.shape)\n",
    "print(\"W_query:\", gqa.W_query.weight.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a5d4c88-c66a-483b-b4e2-419ff9fd60d5",
   "metadata": {
    "id": "1a5d4c88-c66a-483b-b4e2-419ff9fd60d5"
   },
   "source": [
    "- As a side note, to make `GroupedQueryAttention` equivalent to standard multi-head attention, you can set the number of key-value groups (`num_kv_groups`) equal to the number of heads (`num_heads`)\n",
    "- Lastly, let's compare the number of parameters below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "58f713aa-ac00-4e2f-8247-94609aa01350",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "58f713aa-ac00-4e2f-8247-94609aa01350",
    "outputId": "486dfd9c-9f3a-4b9e-f9a2-35fb43b9a5fb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of parameters:\n",
      "MHA: 67,108,864\n",
      "GQA: 41,943,040\n"
     ]
    }
   ],
   "source": [
    "print(\"Total number of parameters:\")\n",
    "\n",
    "mha_total_params = sum(p.numel() for p in mha.parameters())\n",
    "print(f\"MHA: {mha_total_params:,}\")\n",
    "\n",
    "gqa_total_params = sum(p.numel() for p in gqa.parameters())\n",
    "print(f\"GQA: {gqa_total_params:,}\")"
   ]
  },
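  {
   "cell_type": "markdown",
   "id": "attention-param-count-sketch-04",
   "metadata": {},
   "source": [
    "- The two totals above can be reproduced by hand from the weight matrix shapes; the sketch below assumes bias-free linear layers (as in the attention code above) and uses a hypothetical helper named `attention_params`\n",
    "\n",
    "```python\n",
    "def attention_params(d_in, d_out, num_heads, num_kv_groups):\n",
    "    head_dim = d_out // num_heads\n",
    "    query = d_in * d_out                      # W_query\n",
    "    key = d_in * num_kv_groups * head_dim     # W_key\n",
    "    value = d_in * num_kv_groups * head_dim   # W_value\n",
    "    out = d_out * d_out                       # out_proj\n",
    "    return query + key + value + out\n",
    "\n",
    "\n",
    "# Setting num_kv_groups equal to num_heads recovers regular multi-head attention\n",
    "mha_params = attention_params(4096, 4096, num_heads=32, num_kv_groups=32)\n",
    "gqa_params = attention_params(4096, 4096, num_heads=32, num_kv_groups=8)\n",
    "\n",
    "print(f'MHA: {mha_params:,}')  # MHA: 67,108,864\n",
    "print(f'GQA: {gqa_params:,}')  # GQA: 41,943,040\n",
    "```\n",
    "\n",
    "- The query and output projections are unchanged, so all of the savings come from the smaller key and value matrices"
   ]
  },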
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "78b60dfd-6c0f-41f7-8f0c-8e57116f07f5",
   "metadata": {
    "id": "78b60dfd-6c0f-41f7-8f0c-8e57116f07f5"
   },
   "outputs": [],
   "source": [
    "# Free up memory:\n",
    "del mha\n",
    "del gqa"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8fcd8802-2859-45a2-905a-f4fe96629dd9",
   "metadata": {
    "id": "8fcd8802-2859-45a2-905a-f4fe96629dd9"
   },
   "source": [
    "&nbsp;\n",
    "## 1.4 Update the TransformerBlock module"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "KABNccft_YnR",
   "metadata": {
    "id": "KABNccft_YnR"
   },
   "source": [
    "- Next, we update the `TransformerBlock`\n",
    "- Here, we simply swap `MultiHeadAttention` with `GroupedQueryAttention` and add the new RoPE settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f9fa8eb4-7196-4dee-aec6-0dcbc70921c4",
   "metadata": {
    "id": "f9fa8eb4-7196-4dee-aec6-0dcbc70921c4"
   },
   "outputs": [],
   "source": [
    "class TransformerBlock(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.att = GroupedQueryAttention(  # MultiHeadAttention(\n",
    "            d_in=cfg[\"emb_dim\"],\n",
    "            d_out=cfg[\"emb_dim\"],\n",
    "            context_length=cfg[\"context_length\"],\n",
    "            num_heads=cfg[\"n_heads\"],\n",
    "            num_kv_groups=cfg[\"n_kv_groups\"],  # NEW\n",
    "            rope_base=cfg[\"rope_base\"],        # NEW\n",
    "            rope_config=cfg[\"rope_freq\"],      # NEW\n",
    "            dtype=cfg[\"dtype\"]\n",
    "        )\n",
    "        self.ff = FeedForward(cfg)\n",
    "        self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
    "        self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Shortcut connection for attention block\n",
    "        shortcut = x\n",
    "        x = self.norm1(x)\n",
    "        x = self.att(x.to(torch.bfloat16))   # Shape [batch_size, num_tokens, emb_size]\n",
    "        x = x + shortcut  # Add the original input back\n",
    "\n",
    "        # Shortcut connection for feed-forward block\n",
    "        shortcut = x\n",
    "        x = self.norm2(x)\n",
    "        x = self.ff(x.to(torch.bfloat16))\n",
    "        x = x + shortcut  # Add the original input back\n",
    "\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd921ab5-c48c-4c52-bf41-b847b3b822b9",
   "metadata": {
    "id": "fd921ab5-c48c-4c52-bf41-b847b3b822b9"
   },
   "source": [
    "&nbsp;\n",
    "## 1.5 Defining the model class"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "M_tLAq_r_llN",
   "metadata": {
    "id": "M_tLAq_r_llN"
   },
   "source": [
    "- When setting up the model class, we fortunately don't have to do much; we just update the name to `Llama3Model`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "475755d6-01f7-4e6e-ad9a-cec6f031ebf6",
   "metadata": {
    "id": "475755d6-01f7-4e6e-ad9a-cec6f031ebf6"
   },
   "outputs": [],
   "source": [
    "# class Llama2Model(nn.Module):\n",
    "class Llama3Model(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
    "\n",
    "        self.trf_blocks = nn.Sequential(\n",
    "            *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n",
    "\n",
    "        self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
    "        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
    "\n",
    "    def forward(self, in_idx):\n",
    "        tok_embeds = self.tok_emb(in_idx)\n",
    "        x = tok_embeds\n",
    "        x = self.trf_blocks(x)\n",
    "        x = self.final_norm(x)\n",
    "        logits = self.out_head(x.to(torch.bfloat16))\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60",
   "metadata": {
    "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60"
   },
   "source": [
    "&nbsp;\n",
    "# 2. Initialize model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "HoGGRAGykQTE",
   "metadata": {
    "id": "HoGGRAGykQTE"
   },
   "source": [
    "- Now we can define a Llama 3 configuration dictionary (the Llama 2 configuration is shown for comparison)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e0564727-2d35-4f0c-b0fc-cde1e9134a18",
   "metadata": {
    "id": "e0564727-2d35-4f0c-b0fc-cde1e9134a18"
   },
   "outputs": [],
   "source": [
    "LLAMA2_CONFIG_7B = {\n",
    "    \"vocab_size\": 32_000,    # Vocabulary size\n",
    "    \"context_length\": 4096,  # Context length\n",
    "    \"emb_dim\": 4096,         # Embedding dimension\n",
    "    \"n_heads\": 32,           # Number of attention heads\n",
    "    \"n_layers\": 32,          # Number of layers\n",
    "    \"hidden_dim\": 11_008,    # Size of the intermediate dimension in FeedForward\n",
    "    \"dtype\": torch.bfloat16  # Lower-precision dtype to reduce memory usage\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2ad90f82-15c7-4806-b509-e45b56f57db5",
   "metadata": {
    "id": "2ad90f82-15c7-4806-b509-e45b56f57db5"
   },
   "outputs": [],
   "source": [
    "LLAMA3_CONFIG_8B = {\n",
    "    \"vocab_size\": 128_256,   # NEW: Larger vocabulary size\n",
    "    \"context_length\": 8192,  # NEW: Larger context length\n",
    "    \"emb_dim\": 4096,         # Embedding dimension\n",
    "    \"n_heads\": 32,           # Number of attention heads\n",
    "    \"n_layers\": 32,          # Number of layers\n",
    "    \"hidden_dim\": 14_336,    # NEW: Larger size of the intermediate dimension in FeedForward\n",
    "    \"n_kv_groups\": 8,        # NEW: Key-Value groups for grouped-query attention\n",
    "    \"rope_base\": 500_000.0,  # NEW: The base in RoPE's \"theta\" was increased to 500_000\n",
    "    \"rope_freq\": None,       # NEW: Additional configuration for adjusting the RoPE frequencies\n",
    "    \"dtype\": torch.bfloat16  # Lower-precision dtype to reduce memory usage\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "FAP7fiBzkaBz",
   "metadata": {
    "id": "FAP7fiBzkaBz"
   },
   "source": [
    "- Using these settings, we can now initialize a Llama 3 8B model\n",
    "- Note that this requires ~34 GB of memory (for comparison, Llama 2 7B required ~26 GB of memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "7004d785-ac9a-4df5-8760-6807fc604686",
   "metadata": {
    "id": "7004d785-ac9a-4df5-8760-6807fc604686"
   },
   "outputs": [],
   "source": [
    "model = Llama3Model(LLAMA3_CONFIG_8B)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "edea6334-d1fc-427d-9cf2-4af963ff4bfc",
   "metadata": {},
   "source": [
    "- The following cell should print `True` three times, confirming that the mask and RoPE buffers are shared across the transformer blocks instead of being (wastefully) recreated for each block:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee9625cc-9afa-4b11-8aab-d536fd170761",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check buffers\n",
    "print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
    "print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
    "print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8056a521-91a6-440f-8473-591409c3177b",
   "metadata": {},
   "source": [
    "- Let's now also compute the number of trainable parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6079f747-8f20-4c6b-8d38-7156f1101729",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "6079f747-8f20-4c6b-8d38-7156f1101729",
    "outputId": "0a8cd23b-d9fa-4c2d-ca63-3fc79bc4de0d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of parameters: 8,030,261,248\n"
     ]
    }
   ],
   "source": [
    "total_params = sum(p.numel() for p in model.parameters())\n",
    "print(f\"Total number of parameters: {total_params:,}\")"
   ]
  },
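  {
   "cell_type": "markdown",
   "id": "added-param-count-by-hand",
   "metadata": {},
   "source": [
    "- As a sanity check, the total above can be reproduced by hand from the Llama 3 8B configuration values\n",
    "- The sketch below is only an illustration under stated assumptions: bias-free linear layers, one weight vector per RMSNorm layer, key/value projections of width `n_kv_groups * head_dim` (grouped-query attention), and separate (untied) embedding and output-head matrices; the helper name `count_llama3_params` is made up for this example\n",
    "- Under these assumptions, the arithmetic yields 8,030,261,248, which matches the count reported by `model.parameters()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-param-count-by-hand-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_llama3_params(vocab_size, emb_dim, n_heads, n_layers, hidden_dim, n_kv_groups):\n",
    "    # Hypothetical helper for illustration (not part of the model code)\n",
    "    head_dim = emb_dim // n_heads\n",
    "    kv_dim = n_kv_groups * head_dim  # Width of the key and value projections\n",
    "\n",
    "    # W_query and out_proj are emb_dim x emb_dim; W_key and W_value are emb_dim x kv_dim\n",
    "    attn = 2 * emb_dim * emb_dim + 2 * emb_dim * kv_dim\n",
    "    ff = 3 * emb_dim * hidden_dim  # fc1, fc2, fc3\n",
    "    norms = 2 * emb_dim  # norm1 and norm2\n",
    "    per_block = attn + ff + norms\n",
    "\n",
    "    embeddings = 2 * vocab_size * emb_dim  # tok_emb and out_head\n",
    "    final_norm = emb_dim\n",
    "    return n_layers * per_block + embeddings + final_norm\n",
    "\n",
    "\n",
    "# Values copied from the Llama 3 8B configuration\n",
    "print(f\"{count_llama3_params(128_256, 4096, 32, 32, 14_336, 8):,}\")"
   ]
  },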
  {
   "cell_type": "markdown",
   "id": "Bx14NtzWk2wj",
   "metadata": {
    "id": "Bx14NtzWk2wj"
   },
   "source": [
    "- As shown above, the model contains 8 billion parameters\n",
    "- Additionally, we can calculate the memory requirements for this model using the code below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a",
    "outputId": "3425e9ce-d8c0-4b37-bded-a2c60b66a41a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "float32 (PyTorch default): 68.08 GB\n",
      "bfloat16: 34.04 GB\n"
     ]
    }
   ],
   "source": [
    "def model_memory_size(model, input_dtype=torch.float32):\n",
    "    total_params = 0\n",
    "    total_grads = 0\n",
    "    for param in model.parameters():\n",
    "        # Calculate total number of elements per parameter\n",
    "        param_size = param.numel()\n",
    "        total_params += param_size\n",
    "        # Check if gradients are stored for this parameter\n",
    "        if param.requires_grad:\n",
    "            total_grads += param_size\n",
    "\n",
    "    # Calculate buffer size (non-parameters that require memory)\n",
    "    total_buffers = sum(buf.numel() for buf in model.buffers())\n",
    "\n",
    "    # Size in bytes = (Number of elements) * (Size of each element in bytes)\n",
    "    # We assume parameters and gradients are stored in the same type as input dtype\n",
    "    element_size = torch.tensor(0, dtype=input_dtype).element_size()\n",
    "    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n",
    "\n",
    "    # Convert bytes to gigabytes\n",
    "    total_memory_gb = total_memory_bytes / (1024**3)\n",
    "\n",
    "    return total_memory_gb\n",
    "\n",
    "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
    "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
   ]
  },
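  {
   "cell_type": "markdown",
   "id": "added-weights-only-memory",
   "metadata": {},
   "source": [
    "- To see where numbers of this magnitude come from, the sketch below applies the same bytes-per-element arithmetic to the parameter count alone\n",
    "- It is a simplified illustration (the helper name `params_to_gib` is hypothetical) that ignores buffers and activations; the weights-only values are smaller than the totals printed above because `model_memory_size` also counts one gradient value per trainable parameter, plus buffers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-weights-only-memory-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def params_to_gib(n_params, bytes_per_element, with_grads=False):\n",
    "    # Hypothetical helper: memory for the weights, optionally doubled for gradients\n",
    "    n_elements = n_params * (2 if with_grads else 1)\n",
    "    return n_elements * bytes_per_element / (1024**3)\n",
    "\n",
    "\n",
    "n_params = 8_030_261_248  # Parameter count reported above\n",
    "\n",
    "for name, nbytes in [(\"float32\", 4), (\"bfloat16\", 2)]:\n",
    "    weights_only = params_to_gib(n_params, nbytes)\n",
    "    with_grads = params_to_gib(n_params, nbytes, with_grads=True)\n",
    "    print(f\"{name}: weights only {weights_only:.2f} GB, weights + gradients {with_grads:.2f} GB\")"
   ]
  },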
  {
   "cell_type": "markdown",
   "id": "zudd-5PulKFL",
   "metadata": {
    "id": "zudd-5PulKFL"
   },
   "source": [
    "- Lastly, we can also transfer the model to an NVIDIA or Apple Silicon GPU if applicable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "a4c50e19-1402-45b6-8ccd-9077b2ba836d",
   "metadata": {
    "id": "a4c50e19-1402-45b6-8ccd-9077b2ba836d"
   },
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device(\"mps\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "model.to(device);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34",
   "metadata": {
    "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34"
   },
   "source": [
    "&nbsp;\n",
    "## 3. Load tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0eb30f0c-6144-4bed-87d9-6b2bac377005",
   "metadata": {
    "id": "0eb30f0c-6144-4bed-87d9-6b2bac377005"
   },
   "source": [
    "- In this section, we are going to load the tokenizer for the model\n",
    "- Llama 2 used Google's [SentencePiece](https://github.com/google/sentencepiece) tokenizer instead of OpenAI's BPE tokenizer based on the [Tiktoken](https://github.com/openai/tiktoken) library\n",
    "- Llama 3, however, switched back to a BPE tokenizer built on Tiktoken; specifically, it uses the GPT-4 tokenizer with an extended vocabulary\n",
    "- You can find the original Tiktoken adaptation by Meta AI [here](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py) in their official Llama 3 repository\n",
    "- Below, I rewrote the tokenizer code to make it more readable and minimal for this notebook (but the behavior should be similar)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "5f390cbf-8f92-46dc-afe3-d90b5affae10",
   "metadata": {
    "id": "5f390cbf-8f92-46dc-afe3-d90b5affae10"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import tiktoken\n",
    "from tiktoken.load import load_tiktoken_bpe\n",
    "\n",
    "\n",
    "class Tokenizer:\n",
    "    def __init__(self, model_path):\n",
    "        assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n",
    "        mergeable_ranks = load_tiktoken_bpe(model_path)\n",
    "\n",
    "        self.special_tokens = {\n",
    "            \"<|begin_of_text|>\": 128000,\n",
    "            \"<|end_of_text|>\": 128001,\n",
    "            \"<|start_header_id|>\": 128006,\n",
    "            \"<|end_header_id|>\": 128007,\n",
    "            \"<|eot_id|>\": 128009,\n",
    "        }\n",
    "        self.special_tokens.update({\n",
    "            f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n",
    "        })\n",
    "\n",
    "        self.model = tiktoken.Encoding(\n",
    "            name=Path(model_path).name,\n",
    "            pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n",
    "            mergeable_ranks=mergeable_ranks,\n",
    "            special_tokens=self.special_tokens\n",
    "        )\n",
    "\n",
    "\n",
    "    def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n",
    "        if bos:\n",
    "            tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n",
    "        else:\n",
    "            tokens = []\n",
    "\n",
    "        tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n",
    "\n",
    "        if eos:\n",
    "            tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n",
    "        return tokens\n",
    "\n",
    "    def decode(self, tokens):\n",
    "        return self.model.decode(tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a1509f8-8778-4fec-ba32-14d95c646167",
   "metadata": {
    "id": "0a1509f8-8778-4fec-ba32-14d95c646167"
   },
   "source": [
    "- Meta AI shared the original Llama 3 model weights and tokenizer vocabulary on the Hugging Face Hub\n",
    "- We will first download the tokenizer vocabulary from the Hub and load it into the code above"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "KbnlzsbYmJU6",
   "metadata": {
    "id": "KbnlzsbYmJU6"
   },
   "source": [
    "- Please note that Meta AI requires that you accept the Llama 3 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) repository to accept the terms\n",
    "- Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on \"Settings\"\n",
    "\n",
    "\n",
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/settings.webp?1\" width=\"300px\">\n",
    "\n",
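    "- Then, create and copy the access token; the next code cell reads it from a local `config.json` file under the key `HF_ACCESS_TOKEN` (alternatively, you can paste the token directly into the `login` call)\n",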
    "\n",
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/access-token.webp?1\" width=\"600px\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "3357a230-b678-4691-a238-257ee4e80185",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3357a230-b678-4691-a238-257ee4e80185",
    "outputId": "a3652def-ea7f-46fb-f293-2a59affb71a0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
      "Token is valid (permission: read).\n",
      "Your token has been saved to /root/.cache/huggingface/token\n",
      "Login successful\n"
     ]
    }
   ],
   "source": [
    "from huggingface_hub import login\n",
    "import json\n",
    "\n",
    "with open(\"config.json\", \"r\") as config_file:\n",
    "    config = json.load(config_file)\n",
    "    access_token = config[\"HF_ACCESS_TOKEN\"]\n",
    "\n",
    "login(token=access_token)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "IxGh6ZYQo0VN",
   "metadata": {
    "id": "IxGh6ZYQo0VN"
   },
   "source": [
    "- After logging in with the access token, which is necessary to verify that we accepted the Llama 3 licensing terms, we can download the tokenizer vocabulary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
    "outputId": "c9836ba8-5176-4dd5-b618-6cc36fdbe1f0"
   },
   "outputs": [],
   "source": [
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "tokenizer_file_path = hf_hub_download(\n",
    "    repo_id=\"meta-llama/Meta-Llama-3-8B\",\n",
    "    filename=\"original/tokenizer.model\",\n",
    "    local_dir=\"Llama-3-8B\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "F8BH1Nk0AYCS",
   "metadata": {
    "id": "F8BH1Nk0AYCS"
   },
   "source": [
    "- Note that to use the Llama 3 files, we may need the `blobfile` package, which is used when handling datasets or models stored in cloud storage solutions like Google Cloud Storage (GCS), Azure Blob Storage, or Amazon S3\n",
    "- You can install this dependency by uncommenting and executing the `pip` command below\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "5dm6Oz7uAytV",
   "metadata": {
    "id": "5dm6Oz7uAytV"
   },
   "outputs": [],
   "source": [
    "# pip install blobfile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "8b8c0ce6-a6fb-4b8a-8de2-ee7bb7646fd0",
   "metadata": {
    "id": "8b8c0ce6-a6fb-4b8a-8de2-ee7bb7646fd0"
   },
   "outputs": [],
   "source": [
    "tokenizer = Tokenizer(tokenizer_file_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "NVhmFeX3pT_M",
   "metadata": {
    "id": "NVhmFeX3pT_M"
   },
   "source": [
    "- We can now use the `generate` function to have the Llama 3 model generate new text:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
    "outputId": "990d7b74-cb35-476b-d8bd-d544006e00f4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Every effort_dead aeros Ingredients başında.extensionégor clangmissions güc như submodule.and report官方%，.Reader(\",\");\n",
      "ामल ندار Parliamentary !!! HigginsDynamicZhgmt writeln Globalsletion 사진------\n"
     ]
    }
   ],
   "source": [
    "from previous_chapters import generate, text_to_token_ids, token_ids_to_text\n",
    "# If the `previous_chapters.py` file is not available locally,\n",
    "# you can import it from the `llms-from-scratch` PyPI package.\n",
    "# For details, see: https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg\n",
    "# E.g.,\n",
    "# from llms_from_scratch.ch05 import generate, text_to_token_ids, token_ids_to_text\n",
    "\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "token_ids = generate(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
    "    max_new_tokens=30,\n",
    "    context_size=LLAMA3_CONFIG_8B[\"context_length\"],\n",
    "    top_k=1,\n",
    "    temperature=0.\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93WTtAA5paYV",
   "metadata": {
    "id": "93WTtAA5paYV"
   },
   "source": [
    "- Of course, as we can see above, the text is nonsensical since we haven't trained the Llama 3 model yet\n",
    "- In the next section, instead of training it ourselves, which would cost tens to hundreds of thousands of dollars, we load the pretrained weights from Meta AI"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f63cc248-1d27-4eb6-aa50-173b436652f8",
   "metadata": {
    "id": "f63cc248-1d27-4eb6-aa50-173b436652f8"
   },
   "source": [
    "&nbsp;\n",
    "## 4. Load pretrained weights"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aKeN7rUfqZMI",
   "metadata": {
    "id": "aKeN7rUfqZMI"
   },
   "source": [
    "- We are loading the [\"meta-llama/Meta-Llama-3-8B\"](https://huggingface.co/meta-llama/Meta-Llama-3-8B) base model below, which is a simple text completion model before finetuning\n",
    "- Alternatively, you can load the instruction-finetuned and aligned [\"meta-llama/Meta-Llama-3-8B-Instruct\"](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model by modifying the string in the next code cell accordingly\n",
    "- Combined, the weight files are about 16 GB in size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 145,
     "referenced_widgets": [
      "f3788acce34f4956b0727b58d0cf38c6",
      "6022a9426683420690d9b41a0ca4f870",
      "e9aba3d53b4d45c485a7aad649c7b465",
      "f1a12d7929db4309b9881853135359fc",
      "58c9dec75a3346b1b787f88dd510d254",
      "9492edc02dee456f840325d913fa4e4f",
      "66dc94b23556499f985f8accbb1f89cb",
      "7c6658cfff1a4d27af3de148184f77d9",
      "7266a729edfb4a44b5b1c67dc79be146",
      "76dbab4873f342019c5d7624ae2c9775",
      "3cea4b431147441a8d9bd872811d5974",
      "8ae98969541849efa356cf912ac39b1e",
      "f9373112649945e3b446c3e1ec274dc1",
      "d49791082a304ade95c185c79fae1f41",
      "616e383bb3d442bcb6edb2721a8180b6",
      "87f474861e54432e9d533e0a89bb77da",
      "e805bb6dfee34dab8870f4618d8bffdb",
      "be3e9bf271f04eb0b119659e1af3a0ea",
      "00148825ce0248b7a23eb28e3eca6749",
      "f1a9b0c2431640298a6c1b258298b12d",
      "8ba9f009e92a46fcbcbb401dc444f12e",
      "d74186bb74d142dfb683fa347b6990f7",
      "9bb60a5a3710463ebe3a17f8d2a446be",
      "0a08fb81165748748ccb080e6df0600f",
      "603690f543114a7fb6aebd433c80bdc3",
      "773b802daed942f5a11f3eab3b83be08",
      "7989003a613e45f780d3f800e121543a",
      "9d49589118f5432cac49650251046429",
      "f114549fe8ce49638a791ca2fecb2d89",
      "0aa155b794a8426aa265f4a7670f43ad",
      "a06fbde549cc47fdaddfbdb82d35d823",
      "172c0c6955e1428b999dcb2d133704cd",
      "1bf7108774c34016a2193e2cd7639b7d",
      "ed28e180d94a4b7aa548581612e31232",
      "ff4338faded5494da1ccb660e1c441ed",
      "b46a08cf4929422eb0f76d8d9af11249",
      "f049eb4a50f54c34912ca959d2eaf353",
      "80dfd3e80ceb444a83ec1fd65f9af80e",
      "519147a10b984befbd0f255f78c1f66a",
      "562e82438dbe41b793ff488b8447c5bf",
      "1da83719e47c4196b06f3aa32056b560",
      "c4a2c88326d14fbca87cfde073755a2e",
      "f0ab5a46cbb0444c88ed137d8a95002b",
      "f8f28ac0e149428f9fef42373c6a87d0"
     ]
    },
    "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
    "outputId": "c05118ce-9f81-41c8-a1f2-72caa932ae86"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "245443330e4d40c887a5649cc1663e98",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from safetensors.torch import load_file\n",
    "\n",
    "combined_weights = {}\n",
    "\n",
    "for i in range(1, 5):\n",
    "    weights_file = hf_hub_download(\n",
    "        repo_id=\"meta-llama/Meta-Llama-3-8B\",\n",
    "        filename=f\"model-0000{i}-of-00004.safetensors\",\n",
    "        local_dir=\"Llama-3-8B\"\n",
    "    )\n",
    "    current_weights = load_file(weights_file)\n",
    "    combined_weights.update(current_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "-15SJ7btq2zE",
   "metadata": {
    "id": "-15SJ7btq2zE"
   },
   "source": [
    "- The `combined_weights` dictionary contains the following tensors (only the first 15 are shown for simplicity):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "ee26bd0b-fea9-4924-97f7-409c14f28e49",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ee26bd0b-fea9-4924-97f7-409c14f28e49",
    "outputId": "2fbc2786-677f-4fea-9472-5fb8542ff14b"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['model.embed_tokens.weight',\n",
       " 'model.layers.0.input_layernorm.weight',\n",
       " 'model.layers.0.mlp.down_proj.weight',\n",
       " 'model.layers.0.mlp.gate_proj.weight',\n",
       " 'model.layers.0.mlp.up_proj.weight',\n",
       " 'model.layers.0.post_attention_layernorm.weight',\n",
       " 'model.layers.0.self_attn.k_proj.weight',\n",
       " 'model.layers.0.self_attn.o_proj.weight',\n",
       " 'model.layers.0.self_attn.q_proj.weight',\n",
       " 'model.layers.0.self_attn.v_proj.weight',\n",
       " 'model.layers.1.input_layernorm.weight',\n",
       " 'model.layers.1.mlp.down_proj.weight',\n",
       " 'model.layers.1.mlp.gate_proj.weight',\n",
       " 'model.layers.1.mlp.up_proj.weight',\n",
       " 'model.layers.1.post_attention_layernorm.weight']"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(combined_weights.keys())[:15]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "UeeSpnunrDFB",
   "metadata": {
    "id": "UeeSpnunrDFB"
   },
   "source": [
    "- The following function, modeled after the `load_weights_into_gpt` function in [chapter 5](../01_main-chapter-code/ch05.ipynb), loads the pretrained weights into our Llama 3 model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65",
   "metadata": {
    "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65"
   },
   "outputs": [],
   "source": [
    "def assign(left, right, tensor_name=\"unknown\"):\n",
    "    if left.shape != right.shape:\n",
    "        raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}\")\n",
    "\n",
    "    if isinstance(right, torch.Tensor):\n",
    "        return torch.nn.Parameter(right.clone().detach())\n",
    "    else:\n",
    "        return torch.nn.Parameter(torch.tensor(right))\n",
    "\n",
    "\n",
    "def load_weights_into_llama(model, param_config, params):\n",
    "    model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
    "\n",
    "    for l in range(param_config[\"n_layers\"]):\n",
    "\n",
    "        # Load attention weights\n",
    "        model.trf_blocks[l].att.W_query.weight = assign(\n",
    "            model.trf_blocks[l].att.W_query.weight,\n",
    "            params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n",
    "            f\"model.layers.{l}.self_attn.q_proj.weight\"\n",
    "        )\n",
    "        model.trf_blocks[l].att.W_key.weight = assign(\n",
    "            model.trf_blocks[l].att.W_key.weight,\n",
    "            params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n",
    "            f\"model.layers.{l}.self_attn.k_proj.weight\"\n",
    "        )\n",
    "        model.trf_blocks[l].att.W_value.weight = assign(\n",
    "            model.trf_blocks[l].att.W_value.weight,\n",
    "            params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n",
    "            f\"model.layers.{l}.self_attn.v_proj.weight\"\n",
    "        )\n",
    "        model.trf_blocks[l].att.out_proj.weight = assign(\n",
    "            model.trf_blocks[l].att.out_proj.weight,\n",
    "            params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n",
    "            f\"model.layers.{l}.self_attn.o_proj.weight\"\n",
    "        )\n",
    "        model.trf_blocks[l].norm1.weight = assign(\n",
    "            model.trf_blocks[l].norm1.weight,\n",
    "            params[f\"model.layers.{l}.input_layernorm.weight\"],\n",
    "            f\"model.layers.{l}.input_layernorm.weight\"\n",
    "        )\n",
    "\n",
    "        # Load FeedForward weights\n",
    "        model.trf_blocks[l].ff.fc1.weight = assign(\n",
    "            model.trf_blocks[l].ff.fc1.weight,\n",
    "            params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n",
    "            f\"model.layers.{l}.mlp.gate_proj.weight\"\n",
    "        )\n",
    "        model.trf_blocks[l].ff.fc2.weight = assign(\n",
    "            model.trf_blocks[l].ff.fc2.weight,\n",
    "            params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n",
    "            f\"model.layers.{l}.mlp.up_proj.weight\"\n",
    "        )\n",
    "        model.trf_blocks[l].ff.fc3.weight = assign(\n",
    "            model.trf_blocks[l].ff.fc3.weight,\n",
    "            params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n",
    "            f\"model.layers.{l}.mlp.down_proj.weight\"\n",
    "        )\n",
    "        model.trf_blocks[l].norm2.weight = assign(\n",
    "            model.trf_blocks[l].norm2.weight,\n",
    "            params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n",
    "            f\"model.layers.{l}.post_attention_layernorm.weight\"\n",
    "        )\n",
    "\n",
    "    # Load output layer weights\n",
    "    model.final_norm.weight = assign(model.final_norm.weight, params[\"model.norm.weight\"], \"model.norm.weight\")\n",
    "\n",
    "    if \"lm_head.weight\" in params.keys():\n",
    "        model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n",
    "    else:\n",
    "        model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n",
    "        print(\"Model uses weight tying.\")\n",
    "\n",
    "\n",
    "load_weights_into_llama(model, LLAMA3_CONFIG_8B, combined_weights)\n",
    "model.to(device);\n",
    "del combined_weights  # free up memory"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "TDuv_Us2rNvk",
   "metadata": {
    "id": "TDuv_Us2rNvk"
   },
   "source": [
    "- Next, we are ready to use the model for text generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "240987e8-a023-462e-9376-9edfb27559ec",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "240987e8-a023-462e-9376-9edfb27559ec",
    "outputId": "6dab0e56-40a8-45db-a096-ab2b9ee97a69"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Every effort has been made to trace copyright holders and to obtain their permission for the use of copyright material. The publisher apologizes for any\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "token_ids = generate(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
    "    max_new_tokens=25,\n",
    "    context_size=LLAMA3_CONFIG_8B[\"context_length\"],\n",
    "    top_k=1,\n",
    "    temperature=0.\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1203041e-4794-4157-a978-3ce80909da44",
   "metadata": {
    "id": "1203041e-4794-4157-a978-3ce80909da44"
   },
   "source": [
    "&nbsp;\n",
    "## 5. Using the instruction-finetuned model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "akyo7WNyF_YL",
   "metadata": {
    "id": "akyo7WNyF_YL"
   },
   "source": [
    "- Above, we used the pretrained base model; if you want to use a model capable of following instructions, use the `\"meta-llama/Meta-Llama-3-8B-Instruct\"` model instead, as shown below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "hdA-xjjdS26J",
   "metadata": {
    "id": "hdA-xjjdS26J"
   },
   "outputs": [],
   "source": [
    "# Free up memory before loading the next model\n",
    "\n",
    "import gc\n",
    "\n",
    "del model\n",
    "\n",
    "gc.collect()  # Run Python garbage collector\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "nbvAV7vaz6yc",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 145,
     "referenced_widgets": [
      "409470784b6346a981920350de4f6f28",
      "9ba6a11ffd194bf9a0900f52a7ed4d4f",
      "acae8bbbb4a84ed49be72fecd11fb052",
      "e8a4b441281b4038bb0204d093411f68",
      "bdf8b693821344fc97918e6cbc31c8bf",
      "97e8877869cd4be68ff38ce745be5045",
      "cc3da88e93c4499993b7bbb7d3064326",
      "0d51fdc2c416474da04079db6579890f",
      "c4598300a77b4667b1117f9499f5ccb7",
      "77606cd2fe1b4d33a91ede944bb1dec0",
      "f1ba439c26d64c90af2f162c74348405",
      "d598f094c3ce4daeab19fac8094cba7e",
      "0afc2d23514b45c9890b5d2ee4e6fa0b",
      "3da5d38bf3314d3eaa7cedebae41c076",
      "55e6b727a4594078beb3853cc1891308",
      "f17fa78263414ef8b414c7bf3ac03192",
      "e8b187b40ec14db3af17a380830a35bf",
      "e94ca32eaa9f4714a3b05a5fdf24d02b",
      "3edd464991204b8690eae02f10b4cc00",
      "ac1e34f4bd6c420bb6cc2fdde5f3ed4d",
      "1cd5e07cad35450182004952de32c8e7",
      "a63351a6715643378491ba831b3fb05d",
      "98b4680141ee423bb5e43c47613d8440",
      "b02ffefca3f34252914e76f4a8a467dc",
      "31d27bf34a74432f8e0dbfe9ecb76130",
      "a3137f3669b54e84be91010c9654d985",
      "5a2886564d3f40ceaa30b743dbe81f45",
      "15ea8fcfe097471e8fc9502a162f5904",
      "c779e80c50ba4434bfa1d326c5cc9b0f",
      "eb94612785e64552aea8674dc8647a93",
      "279cffe683fe4e7383062162e07ed9ed",
      "6176990205cc499f8995c71fc6b9d4df",
      "66c23ae98bcc45f18fc5c91e0e73c3e4",
      "05b502e1e3a9436297dafbb1ce7af722",
      "25977b0d89084703ad787fe9208b5aad",
      "71a84ee5fc964ec89ff2832c84735cc2",
      "6aed783eccb942318e6384e253ad4924",
      "84c34bfecda64391a609e19f131d51d4",
      "20ecac7c646b45938ed393cb20977c37",
      "ebe04aeaaac042aaaa0885992e45793d",
      "ca81071ab07446df96795a482ce0c630",
      "e0550cab24c7492787af40dc4b8576bf",
      "7015bf6f85954036aaf8cc4f1c44ea0f",
      "2a2ba3d065634484a932b8d3c212af56"
     ]
    },
    "id": "nbvAV7vaz6yc",
    "outputId": "9e1badc9-a6c4-48b7-9125-e0810655528b"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f7df6bbf8e63448c8a6cb5d2f6208403",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00004.safetensors:  36%|###6      | 1.81G/4.98G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4772f31a1c5b4c168c9aabe7a1d2bacc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ad49eeb9e1204ea2bd2e371df8ccdea2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "951b9e81613a40a2a503f61e69677f0a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "combined_weights = {}\n",
    "\n",
    "for i in range(1, 5):\n",
    "    weights_file = hf_hub_download(\n",
    "        repo_id=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
    "        filename=f\"model-0000{i}-of-00004.safetensors\",\n",
    "        local_dir=\"Llama-3-8B-Instruct\"\n",
    "    )\n",
    "    current_weights = load_file(weights_file)\n",
    "    combined_weights.update(current_weights)\n",
    "\n",
    "\n",
    "model = Llama3Model(LLAMA3_CONFIG_8B)\n",
    "load_weights_into_llama(model, LLAMA3_CONFIG_8B, combined_weights)\n",
    "model.to(device)\n",
    "del combined_weights  # free up memory"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "VlH7qYVdDKQr",
   "metadata": {
    "id": "VlH7qYVdDKQr"
   },
   "source": [
    "- Note that the Llama 3 model should ideally be used with the correct prompt template that was used during finetuning (as discussed in chapter 7)\n",
    "- Below is a wrapper class around the tokenizer based on Meta AI's Llama 3-specific [ChatFormat code](https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/tokenizer.py#L202) that constructs the prompt template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "4be5b481-1110-46e8-a931-3988d890cf8c",
   "metadata": {
    "id": "4be5b481-1110-46e8-a931-3988d890cf8c"
   },
   "outputs": [],
   "source": [
    "class ChatFormat:\n",
    "    def __init__(self, tokenizer):\n",
    "        self.tokenizer = tokenizer\n",
    "\n",
    "    def encode_header(self, message):\n",
    "        tokens = []\n",
    "        tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n",
    "        tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n",
    "        tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n",
    "        tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n",
    "        return tokens\n",
    "\n",
    "    def encode(self, text):\n",
    "        message = {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": text\n",
    "        }\n",
    "\n",
    "        tokens = self.encode_header(message)\n",
    "        tokens.extend(\n",
    "            self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n",
    "        )\n",
    "        tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n",
    "        return tokens\n",
    "\n",
    "    def decode(self, token_ids):\n",
    "        return self.tokenizer.decode(token_ids)\n",
    "\n",
    "\n",
    "chat_tokenizer = ChatFormat(tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "M-dkSNvwDttN",
   "metadata": {
    "id": "M-dkSNvwDttN"
   },
   "source": [
    "- The usage is as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "nwBrTGTsUNhn",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "nwBrTGTsUNhn",
    "outputId": "72a495b4-b872-429a-88ef-49a9b4577f0f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[128006, 882, 128007, 271, 9906, 4435, 0, 128009]\n"
     ]
    }
   ],
   "source": [
    "token_ids = chat_tokenizer.encode(\"Hello World!\")\n",
    "print(token_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "0fpmpVgYVTRZ",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 36
    },
    "id": "0fpmpVgYVTRZ",
    "outputId": "bb3e819a-112a-466c-ac51-5d14a9c3475b"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<|start_header_id|>user<|end_header_id|>\\n\\nHello World!<|eot_id|>'"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(token_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "Wo-aUGeKDvqq",
   "metadata": {
    "id": "Wo-aUGeKDvqq"
   },
   "source": [
    "- Let's now see the Llama 3 instruction model in action:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "ozGOBu6XOkEW",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ozGOBu6XOkEW",
    "outputId": "4f689c70-bed9-46f3-a52a-aea47b641283"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Here are some of the things llamas like to eat:\n",
      "\n",
      "1. Grass: Llamas love to graze on grass, especially in the spring and summer months.\n",
      "2. Hay: Hay is a staple in a llama's diet. They like to eat timothy hay, alfalfa hay, and other types of hay.\n",
      "3. Grains: Llamas may also be fed grains like oats, barley, and corn. However, grains should not make up more than 10-15% of a llama's diet.\n",
      "4. Fruits and vegetables: Llamas may enjoy fruits and vegetables as treats, such as\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "token_ids = generate(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(\"What do llamas eat?\", chat_tokenizer).to(device),\n",
    "    max_new_tokens=150,\n",
    "    context_size=LLAMA3_CONFIG_8B[\"context_length\"],\n",
    "    top_k=1,\n",
    "    temperature=0.\n",
    ")\n",
    "\n",
    "output_text = token_ids_to_text(token_ids, tokenizer)\n",
    "\n",
    "\n",
    "def clean_text(text, header_end=\"assistant<|end_header_id|>\\n\\n\"):\n",
    "    # Find the index of the first occurrence of the assistant header (header_end)\n",
    "    index = text.find(header_end)\n",
    "\n",
    "    if index != -1:\n",
    "        # Return the substring that follows the assistant header\n",
    "        return text[index + len(header_end):].strip()  # Strip removes leading/trailing whitespace\n",
    "    else:\n",
    "        # If the header is not found, return the original text\n",
    "        return text\n",
    "\n",
    "print(\"Output text:\\n\", clean_text(output_text))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2r5JKrO-ZOHK",
   "metadata": {
    "id": "2r5JKrO-ZOHK"
   },
   "source": [
    "&nbsp;\n",
    "# Llama 3.1 8B"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "QiQxX0XnP_iC",
   "metadata": {
    "id": "QiQxX0XnP_iC"
   },
   "source": [
    "- A few months after the initial Llama 3 release, Meta AI followed up with their Llama 3.1 suite of models (see the official [Introducing Llama 3.1: Our most capable models to date](https://ai.meta.com/blog/meta-llama-3-1/) announcement blog post for details)\n",
    "- Conveniently, we can reuse our previous Llama 3 code from above to implement Llama 3.1 8B\n",
    "\n",
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/llama3-to-llama31.webp\" width=\"700px\">\n",
    "\n",
    "- The architecture is identical; the only changes are a larger supported context length and a rescaling of the RoPE frequencies, as indicated in the configuration below\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "X5Fg8XUHMv4M",
   "metadata": {
    "id": "X5Fg8XUHMv4M"
   },
   "outputs": [],
   "source": [
    "LLAMA3_CONFIG_8B = {\n",
    "    \"vocab_size\": 128_256,   # Vocabulary size\n",
    "    \"context_length\": 8192,  # Context length\n",
    "    \"emb_dim\": 4096,         # Embedding dimension\n",
    "    \"n_heads\": 32,           # Number of attention heads\n",
    "    \"n_layers\": 32,          # Number of layers\n",
    "    \"hidden_dim\": 14_336,    # Size of the intermediate dimension in FeedForward\n",
    "    \"n_kv_groups\": 8,        # Key-Value groups for grouped-query attention\n",
    "    \"rope_base\": 500_000.0,  # The base in RoPE's \"theta\"\n",
    "    \"rope_freq\": None,       # Additional configuration for adjusting the RoPE frequencies\n",
    "    \"dtype\": torch.bfloat16  # Lower-precision dtype to reduce memory usage\n",
    "}\n",
    "\n",
    "LLAMA31_CONFIG_8B = {\n",
    "    \"vocab_size\": 128_256,      # Vocabulary size\n",
    "    \"context_length\": 131_072,  # NEW: Larger supported context length\n",
    "    \"emb_dim\": 4096,            # Embedding dimension\n",
    "    \"n_heads\": 32,              # Number of attention heads\n",
    "    \"n_layers\": 32,             # Number of layers\n",
    "    \"hidden_dim\": 14_336,       # Size of the intermediate dimension in FeedForward\n",
    "    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
    "    \"rope_base\": 500_000.0,     # The base in RoPE's \"theta\"\n",
    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usage\n",
    "    \"rope_freq\": {              # NEW: RoPE frequency scaling\n",
    "        \"factor\": 8.0,\n",
    "        \"low_freq_factor\": 1.0,\n",
    "        \"high_freq_factor\": 4.0,\n",
    "        \"original_context_length\": 8192,\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d81ee464-c112-43b0-9ee8-70df6ac942d0",
   "metadata": {},
   "source": [
    "- We reduce the context length, and rescale the RoPE base (theta) proportionally, so that the model runs on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a55a8769-1a03-4265-8fd0-15f1c423da53",
   "metadata": {
    "id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New RoPE theta: 31250.0\n"
     ]
    }
   ],
   "source": [
    "old_context_length = LLAMA31_CONFIG_8B[\"context_length\"]\n",
    "LLAMA31_CONFIG_8B[\"context_length\"] = 8192\n",
    "\n",
    "\n",
    "def rescale_theta(theta_old, context_length_old, context_length_new):\n",
    "    scaling_factor = context_length_new / context_length_old\n",
    "    theta_new = theta_old * scaling_factor\n",
    "    return theta_new\n",
    "\n",
    "LLAMA31_CONFIG_8B[\"rope_base\"] = rescale_theta(\n",
    "    LLAMA31_CONFIG_8B[\"rope_base\"],\n",
    "    old_context_length,\n",
    "    LLAMA31_CONFIG_8B[\"context_length\"]\n",
    ")\n",
    "\n",
    "print(\"New RoPE theta:\", LLAMA31_CONFIG_8B[\"rope_base\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "xa3bpMDtTdBs",
   "metadata": {
    "id": "xa3bpMDtTdBs"
   },
   "source": [
    "- As we've seen in the code earlier, the RoPE method uses sinusoidal functions (sine and cosine) to embed positional information directly into the attention mechanism\n",
    "- In Llama 3.1, the additional `rope_freq` configuration introduces adjustments to the inverse frequency calculations\n",
    "- These adjustments influence how different frequency components contribute to the positional embeddings (a detailed explanation is a topic for another time)\n",
    "- Let's try out the Llama 3.1 model in practice; first, we clear out the old model to free up some GPU memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "7dUtYnNUOqhL",
   "metadata": {
    "id": "7dUtYnNUOqhL"
   },
   "outputs": [],
   "source": [
    "# free up memory\n",
    "del model\n",
    "\n",
    "gc.collect()  # Run Python garbage collector\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "DbbVsll6TYWR",
   "metadata": {
    "id": "DbbVsll6TYWR"
   },
   "source": [
    "- Next, we download the tokenizer\n",
    "- Note that since the Llama 3.1 family is distinct from the Llama 3 family, you'd have to go to the [meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) repository and acknowledge the license terms for your Hugging Face access token to work for the download\n",
    "- Tip: For simplicity, we only load the base model below, but there's also an instruction-finetuned version you can use by replacing `\"meta-llama/Llama-3.1-8B\"` with `\"meta-llama/Llama-3.1-8B-Instruct\"`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "8xDk4chtPNU4",
   "metadata": {
    "id": "8xDk4chtPNU4"
   },
   "outputs": [],
   "source": [
    "tokenizer_file_path = hf_hub_download(\n",
    "    repo_id=\"meta-llama/Llama-3.1-8B\",\n",
    "    filename=\"original/tokenizer.model\",\n",
    "    local_dir=\"Llama-3.1-8B\"\n",
    ")\n",
    "\n",
    "tokenizer = Tokenizer(tokenizer_file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "a7l21VE4Otcs",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "a7l21VE4Otcs",
    "outputId": "3dd5cfba-bf3f-44d2-9be1-7cd42bfe4ba9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of parameters: 8,030,261,248\n"
     ]
    }
   ],
   "source": [
    "model = Llama3Model(LLAMA31_CONFIG_8B)\n",
    "\n",
    "total_params = sum(p.numel() for p in model.parameters())\n",
    "print(f\"Total number of parameters: {total_params:,}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "u4J7IxOvOyPM",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 145,
     "referenced_widgets": [
      "5bbaa046d8934c8fae0a12c3d7bd991b",
      "e1e4125eac004bae92dc1f22f673bf0e",
      "d5b4bb4891ec4e44be46e9815c7e10dc",
      "4f6595a392b244bd8e887935defc06f0",
      "100c1b15cc4046cea1147f657eb2d8d0",
      "81458e7953a349cfafccaa213b370406",
      "a3dc9dfadae642b4a873705596739468",
      "f55b59efcefa4ad5955d082f4bf7c637",
      "1b02e0c7d1604b1c87a327c4c4f8b0e7",
      "02ad170019454fd096b37347de5c481d",
      "c52e0f34892b4daa84c1bf61500ac399",
      "af985cf6fa26475eb2c4dd81e0c79ff4",
      "8659c3eddb014c3bb5931fd9e6fadad8",
      "f5fa00d96c4c49e48e1806d23a5b8570",
      "080c484114f64f5591fa1287a35b46c9",
      "14dc6a3717484c55a116612e28447dbb",
      "00d3286c9c1d4161bb777b7b65ae744d",
      "66f27fb11edf453b8144c2dfcdc66baa",
      "5798e5118430439fb1f6bf29e1bafe58",
      "357f367cf74146b8825be371acd51d06",
      "94073be250cd42d5b82e196e30cbf22e",
      "0cd0724f825e480389a82f0c49f91e6d",
      "dffa208978f34e6a9aae94ecda92fe67",
      "b8a98f163ebd4ac89af08a49c0881c23",
      "f0d9febe1a634a0ba7e8e50fa104dcc2",
      "e23870f0c7ff40cc8fa6a1e862a4af99",
      "87da9905a0534c26ad0712ad426ca930",
      "b953419300604b8e86fc0ad003fdfd2f",
      "f1865ed0fbcc40eeabdca90a43d00069",
      "ea0128909a9d4801ba312a876b0cf183",
      "d160986df978416c9ad91d1e10fc90fc",
      "5e97f7c2e8f5453dafcdad0552060e60",
      "4b3e7b8774df4b458bb6c6146fe3226d",
      "2ffd8dbed00e46d2887b9a2590cad297",
      "a06dcb3bdfc84905a7222066c32fe500",
      "e7602abc26714ee890a0cf5c0c7b67e1",
      "dc5d555099f64a998514ebde90eeb6df",
      "ef93a2f58cc54373941f43658bb808cf",
      "fea1e2327d2944859af3d91c216b9008",
      "320c00a5d18c45ccae634d166f1bd810",
      "6c857e69d5204cd3b7c3bf426993ad1f",
      "2145e47428f1446fba3e62b3cde0a7f5",
      "3d519ce3562c4e249bf392c7f43d04c0",
      "cc20ffcf0c1a4656945959bf457dfd84"
     ]
    },
    "id": "u4J7IxOvOyPM",
    "outputId": "925348d7-fc69-4d1b-90f1-7029426bcfcf"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eabfde3ef38b436ea750e6fb50a02b5c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e117ad45771747ae95c16f9876e6dc19",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "170185f2f046437dab57c2ad23163c5c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6e65f5d6c5af4ab78bc7b3778b98ef86",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "combined_weights = {}\n",
    "\n",
    "for i in range(1, 5):\n",
    "    weights_file = hf_hub_download(\n",
    "        repo_id=\"meta-llama/Llama-3.1-8B\",\n",
    "        filename=f\"model-0000{i}-of-00004.safetensors\",\n",
    "        local_dir=\"Llama-3.1-8B\"\n",
    "    )\n",
    "    current_weights = load_file(weights_file)\n",
    "    combined_weights.update(current_weights)\n",
    "\n",
    "load_weights_into_llama(model, LLAMA31_CONFIG_8B, combined_weights)\n",
    "model.to(device);\n",
    "del combined_weights  # free up memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "wJFnF8ATPbtD",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "wJFnF8ATPbtD",
    "outputId": "67d5cb66-3588-4fd4-ac75-39bfe3aa82d8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Every effort has been made to trace copyright holders and to obtain their permission for the use of copyright material. The publisher apologizes for any\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "token_ids = generate(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
    "    max_new_tokens=25,\n",
    "    context_size=LLAMA31_CONFIG_8B[\"context_length\"],\n",
    "    top_k=1,\n",
    "    temperature=0.\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "DR9NBDUjPrDp",
   "metadata": {
    "id": "DR9NBDUjPrDp"
   },
   "source": [
    "&nbsp;\n",
    "# Llama 3.2 1B"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "imoxFiDzJcxk",
   "metadata": {
    "id": "imoxFiDzJcxk"
   },
   "source": [
    "- As of this writing, Meta AI's latest models are the Llama 3.2 models announced [here](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices/)\n",
    "- The code for the Llama 3.2 text model is similar to that of Llama 3.1, except that the model has shrunk in size (there are 1B and 3B versions)\n",
    "- The other efficiency tweak is that they added back weight tying (a concept that was originally used in the GPT-2 architecture); here, they reuse the same weight parameter values in the input (token) embedding layer and output layer\n",
    "- The small model size of Llama 3.2 1B is quite convenient, since it can even run on many mobile devices\n",
    "- The architectural differences between Llama 3.1 8B and Llama 3.2 1B are illustrated in the figure below"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "OL1EoXQ6TPb7",
   "metadata": {
    "id": "OL1EoXQ6TPb7"
   },
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/llama31-to-llama32.webp?1\" width=\"700px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "K0KgjwCCJ9Fb",
   "metadata": {
    "id": "K0KgjwCCJ9Fb"
   },
   "source": [
    "- As we can see in the figure above, the main difference between the Llama 3.1 8B and Llama 3.2 1B architectures is their size\n",
    "- A small additional change is an increased RoPE rescaling factor, which is reflected in the configuration below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "Yv_yF3NCQTBx",
   "metadata": {
    "id": "Yv_yF3NCQTBx"
   },
   "outputs": [],
   "source": [
    "LLAMA31_CONFIG_8B = {\n",
    "    \"vocab_size\": 128_256,      # Vocabulary size\n",
    "    \"context_length\": 131_072,  # NEW: Larger supported context length\n",
    "    \"emb_dim\": 4096,            # Embedding dimension\n",
    "    \"n_heads\": 32,              # Number of attention heads\n",
    "    \"n_layers\": 32,             # Number of layers\n",
    "    \"hidden_dim\": 14_336,       # Size of the intermediate dimension in FeedForward\n",
    "    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
    "    \"rope_base\": 500_000.0,     # The base in RoPE's \"theta\"\n",
    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usage\n",
    "    \"rope_freq\": {              # NEW: RoPE frequency scaling\n",
    "        \"factor\": 8.0,\n",
    "        \"low_freq_factor\": 1.0,\n",
    "        \"high_freq_factor\": 4.0,\n",
    "        \"original_context_length\": 8192,\n",
    "    }\n",
    "}\n",
    "\n",
    "\n",
    "LLAMA32_CONFIG_1B = {\n",
    "    \"vocab_size\": 128_256,      # Vocabulary size\n",
    "    \"context_length\": 131_072,  # Context length\n",
    "    \"emb_dim\": 2048,            # NEW: Half the embedding dimension\n",
    "    \"n_heads\": 32,              # Number of attention heads\n",
    "    \"n_layers\": 16,             # NEW: Half the number of layers\n",
    "    \"hidden_dim\": 8192,         # NEW: Almost half the size of the intermediate dimension in FeedForward\n",
    "    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
    "    \"rope_base\": 500_000.0,     # The base in RoPE's \"theta\"\n",
    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usage\n",
    "    \"rope_freq\": {              # RoPE frequency scaling\n",
    "        \"factor\": 32.0,         # NEW: Adjustment of the rescaling factor\n",
    "        \"low_freq_factor\": 1.0,\n",
    "        \"high_freq_factor\": 4.0,\n",
    "        \"original_context_length\": 8192,\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5cd351b-d883-460d-9cdc-47e15ddb884a",
   "metadata": {},
   "source": [
    "- We reduce the context length, and rescale the RoPE base (theta) proportionally, so that the model runs on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "73f001a6-7ae0-4204-aa83-a27a8878dfd2",
   "metadata": {
    "id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New RoPE theta: 31250.0\n"
     ]
    }
   ],
   "source": [
    "old_context_length = LLAMA32_CONFIG_1B[\"context_length\"]\n",
    "LLAMA32_CONFIG_1B[\"context_length\"] = 8192\n",
    "\n",
    "LLAMA32_CONFIG_1B[\"rope_base\"] = rescale_theta(\n",
    "    LLAMA32_CONFIG_1B[\"rope_base\"],\n",
    "    old_context_length,\n",
    "    LLAMA32_CONFIG_1B[\"context_length\"]\n",
    ")\n",
    "\n",
    "print(\"New RoPE theta:\", LLAMA32_CONFIG_1B[\"rope_base\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "Dl4_0EoJKKYv",
   "metadata": {
    "id": "Dl4_0EoJKKYv"
   },
   "source": [
    "- Below, we can reuse the code from the Llama 3.1 8B section to load the Llama 3.2 1B model\n",
    "- Again, since the Llama 3.2 family is distinct from the Llama 3.1 family, you'd have to go to the [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) repository and acknowledge the license terms for your Hugging Face access token to work for the download\n",
    "- Tip: For simplicity, we only load the base model below, but there's also an instruction-finetuned version you can use by replacing `\"meta-llama/Llama-3.2-1B\"` with `\"meta-llama/Llama-3.2-1B-Instruct\"`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "tCstHgyRRD2x",
   "metadata": {
    "id": "tCstHgyRRD2x"
   },
   "outputs": [],
   "source": [
    "# free up memory\n",
    "del model\n",
    "\n",
    "\n",
    "gc.collect()  # Run Python garbage collector\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "jt8BKAHXRCPI",
   "metadata": {
    "id": "jt8BKAHXRCPI"
   },
   "outputs": [],
   "source": [
    "tokenizer_file_path = hf_hub_download(\n",
    "    repo_id=\"meta-llama/Llama-3.2-1B\",\n",
    "    filename=\"original/tokenizer.model\",\n",
    "    local_dir=\"Llama-3.2-1B\"\n",
    ")\n",
    "\n",
    "tokenizer = Tokenizer(tokenizer_file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "uf8KjasmRFSt",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "uf8KjasmRFSt",
    "outputId": "4e718852-2aa1-4b5a-bec3-3d5f866a4038"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of parameters: 1,498,482,688\n",
      "\n",
      "Total number of unique parameters: 1,235,814,400\n"
     ]
    }
   ],
   "source": [
    "model = Llama3Model(LLAMA32_CONFIG_1B)\n",
    "\n",
    "total_params = sum(p.numel() for p in model.parameters())\n",
    "print(f\"Total number of parameters: {total_params:,}\")\n",
    "\n",
    "# Account for weight tying\n",
    "total_params_normalized = total_params - model.tok_emb.weight.numel()\n",
    "print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "9FbCIYW7RIOe",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "9FbCIYW7RIOe",
    "outputId": "35588405-e2e1-4871-a1db-1d4bcb852e49"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c309c56a6cdf426e8ba7967b6a21864e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model uses weight tying.\n"
     ]
    }
   ],
   "source": [
    "weights_file = hf_hub_download(\n",
    "    repo_id=\"meta-llama/Llama-3.2-1B\",\n",
    "    filename=\"model.safetensors\",\n",
    "    local_dir=\"Llama-3.2-1B\"\n",
    ")\n",
    "current_weights = load_file(weights_file)\n",
    "\n",
    "load_weights_into_llama(model, LLAMA32_CONFIG_1B, current_weights)\n",
    "model.to(device);\n",
    "del current_weights  # free up memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "pPp5yjir6FYJ",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "pPp5yjir6FYJ",
    "outputId": "6c8e79d2-0769-43a7-93b3-f04c030e1aac"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weight tying: True\n"
     ]
    }
   ],
   "source": [
    "print(\"Weight tying:\", torch.equal(model.tok_emb.weight, model.out_head.weight))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "3kh7yrw2W4qr",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3kh7yrw2W4qr",
    "outputId": "b7e66a17-57ec-4b0e-c4ff-8d9a6b8e6ea5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Every effort is made to ensure that the information on this website is accurate. However, we cannot guarantee that the information is accurate, complete\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "token_ids = generate(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
    "    max_new_tokens=25,\n",
    "    context_size=LLAMA32_CONFIG_1B[\"context_length\"],\n",
    "    top_k=1,\n",
    "    temperature=0.\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "VO4Qf0zyW1ZC",
   "metadata": {
    "id": "VO4Qf0zyW1ZC"
   },
   "source": [
    "&nbsp;\n",
    "# What's next?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "CjCewpo2XPAd",
   "metadata": {
    "id": "CjCewpo2XPAd"
   },
   "source": [
    "- This notebook concludes the conversion from GPT to Llama 3.2\n",
    "- If you are interested in a more compact, standalone notebook, which only contains the Llama 3.2 code, check out the [standalone-llama32.ipynb](standalone-llama32.ipynb) notebook"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
